線形回帰(No library)
特徴量が1次元の場合(原点を通る)
計算式
https://gyazo.com/ef45b49dac030bb980d26983d3e54a26
Coding
code: Python
import numpy as np
import matplotlib.pyplot as plt
def reg1dim1(x, y):
return np.dot(x, y) / (x**2).sum()
a = reg1dim1(x, y)
plt.scatter(x, y, color='k')
x_max = x.max()
y_max = a * x_max
plt.title('y = {:.2f}x'.format(a))
plt.show()
https://gyazo.com/7ff8b63c35b6b8f0931c78aa6bdc206d
特徴量が1次元の場合(一般)
計算式
https://gyazo.com/919730cbfd49ba4fd20b60638c28079chttps://gyazo.com/9bca0af0bbe613a8a84842315c186845
Coding
code: Python
import numpy as np
import matplotlib.pyplot as plt
def reg1dim2(x, y):
n = len(x)
a = (np.dot(x, y) - (x.sum() * y.sum() / n)) / (np.dot(x, x) - (x.sum()**2 / n))
b = (y - a * x).sum() / n
return a, b
a, b = reg1dim2(x, y)
plt.scatter(x, y, color='k')
x_max = x.max()
y_max = a * x.max() + b
plt.title('y = {a}x + {b}'.format(a=a, b=b))
plt.show()
https://gyazo.com/7fa83b62a7e45dc6f46f20e24885971c
特徴量が多次元の場合
計算式
https://gyazo.com/a2871a877f08e2c4487189aa0b7d0e78
Coding
code: Python
import numpy as np
from scipy import linalg
class LinearRegression:
def __init__(self):
self.w_ = None
def fit(self, X, t):
"""訓練データによる学習を行う.
X: 入力訓練データ
t: 出力訓練データ
"""
# np.c_で行に対して結合する(横に結合)
# 行列Xの左に要素1からなる列を1つ加えたもの
Xtil = np.c_[np.ones(X.shape0), X] A = np.dot(Xtil.T, Xtil)
b = np.dot(Xtil.T, t)
# 計算結果を格納
self.w_ = linalg.solve(A, b)
def predict(self, X):
if X.ndim == 1:
X = X.reshape(1, -1)
Xtil = np.c_[np.ones(X.shape0), X] return np.dot(Xtil, self.w_)
code: Python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
n = 100
scale = 10
np.random.seed(0)
X = np.random.random((n, 2)) * scale
w0 = 1
w1 = 2
w2 = 3
# 線形和に乱数を足したもの
y = w0 + w1 * X:, 0 + w2 * X:, 1 + np.random.randn(n) model = LinearRegression()
model.fit(X, y)
print("係数:", model.w_)
print("(1, 1)に対する予測値:", model.predict(np.array(1, 1))) # 以下、可視化の処理
xmesh, ymesh = np.meshgrid(np.linspace(0, scale, 20), np.linspace(0, scale, 20))
zmesh = (model.w_0 + model.w_1 * xmesh.ravel() + model.w_2 * ymesh.ravel()).reshape(xmesh.shape) fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(xmesh, ymesh, zmesh, color="r")
plt.show()
--------------------------------------------------------------------------
--------------------------------------------------------------------------
https://gyazo.com/bd96309b555c430421dccbfa8e1d49e3